
Commit 4d33b9c

Merge branch 'apple:main' into elastic
2 parents 4426abc + 1c137ff

38 files changed (+1320, -584 lines)

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -159,6 +159,9 @@ cython_debug/
 # Vscode
 .vscode/
 
+# Zed
+.zed/
+
 # Weights & Biases
 wandb/

axlearn/audio/frontend.py

Lines changed: 2 additions & 9 deletions
@@ -14,6 +14,7 @@
 
 from axlearn.audio.frontend_utils import (
     WindowType,
+    cast_for_rfft,
     frame,
     frame_paddings,
     linear_to_log_mel_spectrogram,
@@ -143,14 +144,6 @@ def _fft_dtype(input_dtype: jnp.dtype) -> jnp.dtype:
         raise ValueError(f"{input_dtype=} is not supported.")
 
 
-def _cast_for_rfft(x: Tensor) -> Tensor:
-    # jnp.fft.rfft input must be float32 or float64.
-    if x.dtype in (jnp.float32, jnp.float64):
-        return x
-    else:
-        return x.astype(jnp.float32)
-
-
 class LogMelFrontend(BaseFrontend):
     """Computes Log Mel spectrogram features.
 
@@ -200,7 +193,7 @@ def __init__(self, cfg: Config, *, parent: Optional[Module]):
         if cfg.fft is not None:
             self._fft = cfg.fft.set(n=fft_size).instantiate()
         else:
-            self._fft = lambda x: jnp.fft.rfft(_cast_for_rfft(x), n=fft_size)
+            self._fft = lambda x: jnp.fft.rfft(cast_for_rfft(x), n=fft_size)
 
         spectrogram = maybe_set_config(
             cfg.spectrogram,

axlearn/audio/frontend_utils.py

Lines changed: 18 additions & 2 deletions
@@ -10,7 +10,6 @@
 
 import enum
 import math
-from functools import partial
 from typing import Callable, Union
 
 import jax
@@ -404,8 +403,25 @@ def sharded_fft(n: int, partition_spec: PartitionSpec) -> Callable[[Tensor], Ten
         A callable that computes FFT.
     """
     return shard_map(
-        partial(jnp.fft.rfft, n=n),
+        lambda x: jnp.fft.rfft(cast_for_rfft(x), n=n),
         mesh=thread_resources.env.physical_mesh,
         in_specs=partition_spec,
         out_specs=partition_spec,
     )
+
+
+def cast_for_rfft(x: Tensor) -> Tensor:
+    """Casts the input tensor to a valid dtype for jnp.fft.rfft if necessary.
+
+    jnp.fft.rfft requires the input to be of dtype float32 or float64.
+
+    Args:
+        x: Input tensor of arbitrary dtype.
+
+    Returns:
+        A tensor of dtype float32 or float64, suitable for jnp.fft.rfft.
+    """
+    if x.dtype in (jnp.float32, jnp.float64):
+        return x
+    else:
+        return x.astype(jnp.float32)
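A minimal usage sketch of the now-shared helper, assuming the axlearn.audio.frontend_utils module from this commit is importable; the frame shape and FFT size below are illustrative, not taken from the diff:

import jax.numpy as jnp

from axlearn.audio.frontend_utils import cast_for_rfft

# bfloat16 frames are common when the surrounding model runs in bfloat16,
# but jnp.fft.rfft only accepts float32/float64 inputs.
frames = jnp.ones((8, 100, 400), dtype=jnp.bfloat16)  # illustrative shape

# cast_for_rfft promotes unsupported dtypes to float32 and passes
# float32/float64 inputs through unchanged.
ffts = jnp.fft.rfft(cast_for_rfft(frames), n=512)
print(ffts.dtype)  # complex64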

axlearn/audio/frontend_utils_test.py

Lines changed: 5 additions & 6 deletions
@@ -23,6 +23,7 @@
 from axlearn.audio import frontend_utils
 from axlearn.audio.frontend_utils import (
     WindowType,
+    cast_for_rfft,
     frame,
     frame_paddings,
     linear_to_log_mel_spectrogram,
@@ -393,15 +394,13 @@ def _ref_log_mel_spectrogram(
 
 
 class ShardedFftTest(TestCase):
+    @parameterized.parameters(jnp.float32, jnp.bfloat16)
     @set_threefry_partitionable(False)  # TODO(Luzy): update for threefry_partitionable True
-    def test_fft(self):
+    def test_fft(self, dtype):
         input_shape = (8, 800, 400)
         fft_size = 512
         inputs = jax.random.uniform(
-            jax.random.PRNGKey(123),
-            shape=input_shape,
-            minval=-32768.0,
-            maxval=32768.0,
+            jax.random.PRNGKey(123), shape=input_shape, minval=-32768.0, maxval=32768.0, dtype=dtype
         )
         with Mesh(
             mesh_utils.create_device_mesh((len(jax.devices()), 1)), ("data", "model")
@@ -414,7 +413,7 @@ def test_fft(self):
             fft_fn = jax.jit(
                 sharded_fft(n=fft_size, partition_spec=PartitionSpec("data", None, None))
             )
-            ref_ffts = jax.jit(jnp.fft.rfft, static_argnames="n")(inputs, n=fft_size)
+            ref_ffts = jax.jit(jnp.fft.rfft, static_argnames="n")(cast_for_rfft(inputs), n=fft_size)
             test_ffts = fft_fn(inputs)
 
             assert_allclose(ref_ffts, test_ffts, rtol=1e-3)

axlearn/cloud/common/bundler.py

Lines changed: 10 additions & 1 deletion
@@ -62,6 +62,7 @@
     copy_blobs,
     get_pyproject_version,
     parse_kv_flags,
+    to_bool,
 )
 from axlearn.common.config import REQUIRED, Configurable, Required, config_class
 from axlearn.common.file_system import copy, exists, makedirs
@@ -341,9 +342,17 @@ def from_spec(cls, spec: list[str], *, fv: Optional[flags.FlagValues]) -> Config
         cfg: BaseDockerBundler.Config = super().from_spec(spec, fv=fv)
         kwargs = parse_kv_flags(spec, delimiter="=")
         cache_from = canonicalize_to_list(kwargs.pop("cache_from", None))
+        skip_bundle = to_bool(kwargs.pop("skip_bundle", False))
+        allow_dirty = to_bool(kwargs.pop("allow_dirty", False))
         # Non-config specs are treated as build args.
         build_args = {k: kwargs.pop(k) for k in list(kwargs.keys()) if k not in cfg}
-        return cfg.set(build_args=build_args, cache_from=cache_from, **kwargs)
+        return cfg.set(
+            build_args=build_args,
+            cache_from=cache_from,
+            skip_bundle=skip_bundle,
+            allow_dirty=allow_dirty,
+            **kwargs,
+        )
 
     # pylint: disable-next=arguments-renamed
     def id(self, tag: str) -> str:

axlearn/cloud/common/event_queue.py

Lines changed: 17 additions & 13 deletions
@@ -203,7 +203,6 @@ def publish(self, event: Event):
             try:
                 # Ensure connection is established before publishing.
                 if not self._channel or not self._connection:
-                    logging.error("RabbitMQ publisher channel is closed, reconnecting...")
                     self.connect()
 
                 # Setting durable=True ensures that the queue will survive.
@@ -230,25 +229,30 @@ def publish(self, event: Event):
                 # Only retry on recoverable exceptions.
                 # AMQPConnectionError is assumed to be related to network issues,
                 # or temporary unavailable host.
-                logging.error(
-                    "Failed to publish event: %s. Error: %s. Attempt: %d",
-                    message,
-                    str(e),
-                    attempt,
-                )
                 self._handle_publish_error()
                 attempt += 1
-                if attempt <= self._num_tries:
+                if attempt < self._num_tries:
                     time.sleep(2**attempt)
+                else:
+                    logging.error(
+                        "Failed to publish event: %s after %d attempts. Error: %s.",
+                        message,
+                        attempt,
+                        str(e),
+                    )
             except Exception as e:  # pylint: disable=broad-except
-                # Unknown errors. Don't retry. Log to avoid crashing clients.
-                logging.error(
-                    "Unknown error. Failed to publish event: %s. Error: %s.", message, str(e)
-                )
                 self._handle_publish_error()
                 attempt += 1
-                if attempt <= self._num_tries:
+                if attempt < self._num_tries:
                     time.sleep(2**attempt)
+                else:
+                    # Unknown errors. Don't retry. Log to avoid crashing clients.
+                    logging.error(
+                        "Unknown error. Failed to publish event: %s after %d attempts. Error: %s.",
+                        message,
+                        attempt,
+                        str(e),
+                    )
 
     def _handle_publish_error(self):
         """Handle publish errors with retrying on connection issue."""

axlearn/cloud/common/utils.py

Lines changed: 13 additions & 0 deletions
@@ -292,6 +292,19 @@ def merge(base: dict, overrides: dict):
     return base
 
 
+def to_bool(value: Any) -> bool:
+    """Converts a string representation of truth to a bool."""
+    if isinstance(value, bool):
+        return value
+    elif isinstance(value, str):
+        val_lower = value.lower()
+        if val_lower == "true":
+            return True
+        elif val_lower == "false":
+            return False
+    raise ValueError(f"Invalid truth value: '{value}'")
+
+
 _Row = list[Any]
 
 
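A quick usage sketch of the new helper, mirroring the parameterized cases in the test file below; only real bools and case-insensitive "true"/"false" strings are accepted:

from axlearn.cloud.common.utils import to_bool

assert to_bool("true") is True
assert to_bool("False") is False
assert to_bool(True) is True

# Anything else (e.g. "yes" or 1) raises ValueError.
try:
    to_bool("yes")
except ValueError as e:
    print(e)  # Invalid truth value: 'yes'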

axlearn/cloud/common/utils_test.py

Lines changed: 18 additions & 0 deletions
@@ -222,6 +222,24 @@ def test_canonicalize(self, v_seq: Sequence[str], v_str: str, v_list: str, delim
     def test_merge(self, base, overrides, expected):
         self.assertEqual(expected, utils.merge(base, overrides))
 
+    @parameterized.parameters(
+        ("true", True),
+        ("True", True),
+        ("false", False),
+        ("False", False),
+        (True, True),
+        (False, False),
+        ("yes", ValueError),
+        (1, ValueError),
+    )
+    def test_to_bool(self, value, expected):
+        if isinstance(expected, type) and issubclass(expected, Exception):
+            with self.assertRaises(expected):
+                utils.to_bool(value)
+        else:
+            result = utils.to_bool(value)
+            self.assertEqual(result, expected)
+
     def test_infer_resources(self):
         @config_class
         class DummyConfig(ConfigBase):

axlearn/cloud/gcp/bundler.py

Lines changed: 15 additions & 6 deletions
@@ -58,7 +58,7 @@
 from axlearn.cloud.common.bundler import main_flags as bundler_main_flags
 from axlearn.cloud.common.bundler import register_bundler
 from axlearn.cloud.common.docker import registry_from_repo
-from axlearn.cloud.common.utils import canonicalize_to_list
+from axlearn.cloud.common.utils import canonicalize_to_list, to_bool
 from axlearn.cloud.gcp.cloud_build import get_cloud_build_status
 from axlearn.cloud.gcp.config import gcp_settings
 from axlearn.cloud.gcp.utils import common_flags
@@ -148,9 +148,7 @@ def from_spec(
         cfg.project = cfg.project or gcp_settings("project", required=False, fv=fv)
         cfg.repo = cfg.repo or gcp_settings("docker_repo", required=False, fv=fv)
         cfg.dockerfile = cfg.dockerfile or gcp_settings("default_dockerfile", required=False, fv=fv)
-        # The value from from_spec is a str and will result in wrong condition.
-        if isinstance(cfg.is_async, str):
-            cfg.is_async = cfg.is_async.lower() != "false"
+        cfg.is_async = to_bool(cfg.is_async)
         return cfg
 
     # pylint: disable-next=no-self-use,unused-argument
@@ -227,19 +225,30 @@ def _build_and_push(
         print(subprocess.run(cmd, check=True))
         return image
 
-    def wait_until_finished(self, name: str):
+    def wait_until_finished(self, name: str, wait_timeout=3600):
         """Waits for async CloudBuild to finish by polling for status.
 
         Is a no-op if `cfg.is_async` is False.
 
         Args:
             name: Bundle name.
+            wait_timeout: Overall timeout in seconds. Defaults to 1 hour.
 
         Raises:
-            ValueError: If async build failed.
+            TimeoutError: If the build does not complete within the overall timeout.
+            ValueError: If the async build fails.
         """
+        start_time = time.perf_counter()
         cfg: CloudBuildBundler.Config = self.config
         while cfg.is_async:
+            elapsed_time = time.perf_counter() - start_time
+            if elapsed_time > wait_timeout:
+                timeout_msg = (
+                    f"Timed out waiting for CloudBuild to finish for more than "
+                    f"{wait_timeout} seconds."
+                )
+                logging.error(timeout_msg)
+                raise TimeoutError(timeout_msg)
            try:
                 build_status = get_cloud_build_status(
                     project_id=cfg.project, image_name=self.id(name), tags=[name]
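The added wait_timeout turns the status-polling loop into a bounded wait. A minimal sketch of the same pattern in isolation; wait_for, poll_fn, and the 30-second poll interval are illustrative assumptions, not the bundler's API:

import time


def wait_for(poll_fn, *, wait_timeout: float = 3600.0, poll_interval: float = 30.0) -> None:
    """Polls poll_fn() until it returns True, raising TimeoutError past wait_timeout seconds."""
    start_time = time.perf_counter()
    while True:
        if time.perf_counter() - start_time > wait_timeout:
            raise TimeoutError(
                f"Timed out waiting for CloudBuild to finish for more than {wait_timeout} seconds."
            )
        if poll_fn():  # e.g. True once the build status is a terminal success state.
            return
        time.sleep(poll_interval)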

axlearn/cloud/gcp/bundler_test.py

Lines changed: 18 additions & 0 deletions
@@ -180,3 +180,21 @@ def test_wait_until_finished_retries_with_runtime_error(self):
             b = cfg.set(is_async=True).instantiate()
             b.wait_until_finished("test-name")
         self.assertEqual(2, mock_status.call_count)
+
+    def test_wait_until_finished_triggers_timeout(self):
+        # Tests that we raise a timeout error if wait_until_finished takes more than 1 hr.
+        cfg = self._get_test_cloud_build_bundler()
+
+        with mock.patch("time.perf_counter") as mock_perf_counter:
+            mock_perf_counter.side_effect = [0, 10, 500, 3601]
+
+            with self._mock_status(
+                None, CloudBuildStatus.PENDING, CloudBuildStatus.PENDING
+            ) as mock_status:
+                b = cfg.set(is_async=True).instantiate()
+                with self.assertRaisesRegex(
+                    TimeoutError,
+                    "Timed out waiting for CloudBuild to finish for more than 3600 seconds.",
+                ):
+                    b.wait_until_finished("test-name")
+                self.assertEqual(2, mock_status.call_count)
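A brief sketch of the time-mocking technique this test relies on: patching time.perf_counter with a side_effect list makes each call return the next value, so the fourth reading (3601) pushes the elapsed time past the 3600-second budget without any real waiting. elapsed_since is an illustrative helper, not part of the test suite:

import time
from unittest import mock


def elapsed_since(start: float) -> float:
    # Looks up time.perf_counter at call time, so the patch below applies.
    return time.perf_counter() - start


with mock.patch("time.perf_counter") as mock_perf_counter:
    # Each call to time.perf_counter() pops the next value from this list.
    mock_perf_counter.side_effect = [0, 10, 500, 3601]
    start = time.perf_counter()   # -> 0
    print(elapsed_since(start))   # 10
    print(elapsed_since(start))   # 500
    print(elapsed_since(start))   # 3601: past a 3600-second budget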
