Skip to content

Commit cea8176

Browse files
Merge pull request jax-ml#24751 from Stella-S-Yan:feature/default_device_str
PiperOrigin-RevId: 696560063
2 parents 05716b5 + afa518a commit cea8176

File tree

3 files changed

+17
-11
lines changed

3 files changed

+17
-11
lines changed

jax/_src/config.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1570,20 +1570,21 @@ def _update_default_device_thread_local(val):
15701570

15711571

15721572
def _validate_default_device(val):
1573-
if val is not None and not isinstance(val, xla_client.Device):
1573+
if (val is not None and
1574+
not isinstance(val, xla_client.Device) and
1575+
val not in ['cpu', 'gpu', 'tpu']):
15741576
# TODO(skyewm): this is a workaround for non-PJRT Device types. Remove when
15751577
# all JAX backends use a single C++ device interface.
15761578
if 'Device' in str(type(val)):
15771579
logger.info(
15781580
'Allowing non-`xla_client.Device` default device: %s, type: %s',
15791581
repr(val), type(val))
15801582
return
1581-
raise ValueError('jax.default_device must be passed a Device object (e.g. '
1582-
f"`jax.devices('cpu')[0]`), got: {val!r}")
1583+
raise ValueError('jax.default_device must be passed either a Device object (e.g. '
1584+
f"`jax.devices('cpu')[0]`) or a platform name string like 'cpu' or 'gpu'"
1585+
f", got: {val!r}")
15831586

15841587

1585-
# TODO(skye): default_device only accepts devices for now. Make it work with
1586-
# platform names as well (e.g. "cpu" to mean the same as jax.devices("cpu")[0]).
15871588
default_device = string_or_object_state(
15881589
name='jax_default_device',
15891590
default=None,

jax/_src/interpreters/pxla.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1711,7 +1711,10 @@ class DeviceAssignmentMismatchError(Exception):
17111711

17121712

17131713
def _get_default_device() -> xc.Device:
1714-
return config.default_device.value or xb.local_devices()[0]
1714+
if isinstance(config.default_device.value, str):
1715+
return xb.get_backend(config.default_device.value).local_devices()[0]
1716+
else:
1717+
return config.default_device.value or xb.local_devices()[0]
17151718

17161719

17171720
def _get_and_check_device_assignment(

tests/api_test.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -287,13 +287,15 @@ def test_jit_default_device(self, module):
287287
self.assertEqual(f(sticky).devices(), system_default_devices)
288288
self.assertEqual(f(1).devices(), system_default_devices)
289289

290-
# TODO(skye): make this work!
291290
def test_jit_default_platform(self):
292-
with self.assertRaisesWithLiteralMatch(
293-
ValueError, "jax.default_device must be passed a Device object "
294-
"(e.g. `jax.devices('cpu')[0]`), got: 'cpu'"):
295291
with jax.default_device("cpu"):
296-
jax.jit(lambda x: x + 1)(1)
292+
result = jax.jit(lambda x: x + 1)(1)
293+
self.assertEqual(result.device.platform, "cpu")
294+
self.assertEqual(result.device, jax.local_devices(backend="cpu")[0])
295+
296+
result = jax.jit(lambda x: x + 1)(1)
297+
self.assertEqual(result.device.platform, jax.default_backend())
298+
self.assertEqual(result.device, jax.local_devices()[0])
297299

298300
def test_complex_support(self):
299301
self.assertEqual(jit(lambda x: x + 1)(1 + 1j), 2 + 1j)

0 commit comments

Comments
 (0)