Skip to content

Commit afa518a

Browse files
committed
Allow setting default_device with platform names.
1 parent 5615028 commit afa518a

File tree

3 files changed

+18
-11
lines changed

3 files changed

+18
-11
lines changed

jax/_src/config.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1561,20 +1561,21 @@ def _update_default_device_thread_local(val):
15611561

15621562

15631563
def _validate_default_device(val):
1564-
if val is not None and not isinstance(val, xla_client.Device):
1564+
if (val is not None and
1565+
not isinstance(val, xla_client.Device) and
1566+
val not in ['cpu', 'gpu', 'tpu']):
15651567
# TODO(skyewm): this is a workaround for non-PJRT Device types. Remove when
15661568
# all JAX backends use a single C++ device interface.
15671569
if 'Device' in str(type(val)):
15681570
logger.info(
15691571
'Allowing non-`xla_client.Device` default device: %s, type: %s',
15701572
repr(val), type(val))
15711573
return
1572-
raise ValueError('jax.default_device must be passed a Device object (e.g. '
1573-
f"`jax.devices('cpu')[0]`), got: {val!r}")
1574+
raise ValueError('jax.default_device must be passed either a Device object (e.g. '
1575+
f"`jax.devices('cpu')[0]`) or a platform name string like 'cpu' or 'gpu'"
1576+
f", got: {val!r}")
15741577

15751578

1576-
# TODO(skye): default_device only accepts devices for now. Make it work with
1577-
# platform names as well (e.g. "cpu" to mean the same as jax.devices("cpu")[0]).
15781579
default_device = string_or_object_state(
15791580
name='jax_default_device',
15801581
default=None,

jax/_src/interpreters/pxla.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1710,7 +1710,10 @@ class DeviceAssignmentMismatchError(Exception):
17101710

17111711

17121712
def _get_default_device() -> xc.Device:
1713-
return config.default_device.value or xb.local_devices()[0]
1713+
if isinstance(config.default_device.value, str):
1714+
return xb.get_backend(config.default_device.value).local_devices()[0]
1715+
else:
1716+
return config.default_device.value or xb.local_devices()[0]
17141717

17151718

17161719
def _get_and_check_device_assignment(
@@ -1742,6 +1745,7 @@ def _get_and_check_device_assignment(
17421745
raise DeviceAssignmentMismatchError([
17431746
DeviceAssignmentMismatch(devices, MismatchType.CONTEXT_DEVICES, None),
17441747
DeviceAssignmentMismatch(arr_device_assignment, s_type, source_info)])
1748+
17451749
if first_sharding_info is None and devices:
17461750
final_device_assignment = devices
17471751
elif first_sharding_info is None:
@@ -2190,6 +2194,7 @@ def lower_sharding_computation(
21902194
assert len(out_shardings) == len(out_layouts) == len(global_out_avals), (
21912195
len(out_shardings), len(out_layouts), len(global_out_avals))
21922196

2197+
21932198
devices_from_context = (None if context_mesh is None or context_mesh.empty
21942199
else context_mesh._flat_devices_tuple)
21952200
# Device assignment across all inputs, outputs and shardings inside jaxpr

tests/api_test.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -287,13 +287,14 @@ def test_jit_default_device(self, module):
287287
self.assertEqual(f(sticky).devices(), system_default_devices)
288288
self.assertEqual(f(1).devices(), system_default_devices)
289289

290-
# TODO(skye): make this work!
291290
def test_jit_default_platform(self):
292-
with self.assertRaisesWithLiteralMatch(
293-
ValueError, "jax.default_device must be passed a Device object "
294-
"(e.g. `jax.devices('cpu')[0]`), got: 'cpu'"):
295291
with jax.default_device("cpu"):
296-
jax.jit(lambda x: x + 1)(1)
292+
result = jax.jit(lambda x: x + 1)(1)
293+
self.assertEqual(result.device.platform, "cpu")
294+
295+
result = jax.jit(lambda x: x + 1)(1)
296+
self.assertEqual(result.device.platform, jax.default_backend())
297+
297298

298299
def test_complex_support(self):
299300
self.assertEqual(jit(lambda x: x + 1)(1 + 1j), 2 + 1j)

0 commit comments

Comments
 (0)