File tree Expand file tree Collapse file tree 3 files changed +17
-11
lines changed Expand file tree Collapse file tree 3 files changed +17
-11
lines changed Original file line number Diff line number Diff line change @@ -1570,20 +1570,21 @@ def _update_default_device_thread_local(val):
15701570
15711571
15721572def _validate_default_device (val ):
1573- if val is not None and not isinstance (val , xla_client .Device ):
1573+ if (val is not None and
1574+ not isinstance (val , xla_client .Device ) and
1575+ val not in ['cpu' , 'gpu' , 'tpu' ]):
15741576 # TODO(skyewm): this is a workaround for non-PJRT Device types. Remove when
15751577 # all JAX backends use a single C++ device interface.
15761578 if 'Device' in str (type (val )):
15771579 logger .info (
15781580 'Allowing non-`xla_client.Device` default device: %s, type: %s' ,
15791581 repr (val ), type (val ))
15801582 return
1581- raise ValueError ('jax.default_device must be passed a Device object (e.g. '
1582- f"`jax.devices('cpu')[0]`), got: { val !r} " )
1583+ raise ValueError ('jax.default_device must be passed either a Device object (e.g. '
1584+ f"`jax.devices('cpu')[0]`) or a platform name string like 'cpu' or 'gpu'"
1585+ f", got: { val !r} " )
15831586
15841587
1585- # TODO(skye): default_device only accepts devices for now. Make it work with
1586- # platform names as well (e.g. "cpu" to mean the same as jax.devices("cpu")[0]).
15871588default_device = string_or_object_state (
15881589 name = 'jax_default_device' ,
15891590 default = None ,
Original file line number Diff line number Diff line change @@ -1711,7 +1711,10 @@ class DeviceAssignmentMismatchError(Exception):
17111711
17121712
17131713def _get_default_device () -> xc .Device :
1714- return config .default_device .value or xb .local_devices ()[0 ]
1714+ if isinstance (config .default_device .value , str ):
1715+ return xb .get_backend (config .default_device .value ).local_devices ()[0 ]
1716+ else :
1717+ return config .default_device .value or xb .local_devices ()[0 ]
17151718
17161719
17171720def _get_and_check_device_assignment (
Original file line number Diff line number Diff line change @@ -287,13 +287,15 @@ def test_jit_default_device(self, module):
287287 self .assertEqual (f (sticky ).devices (), system_default_devices )
288288 self .assertEqual (f (1 ).devices (), system_default_devices )
289289
290- # TODO(skye): make this work!
291290 def test_jit_default_platform (self ):
292- with self .assertRaisesWithLiteralMatch (
293- ValueError , "jax.default_device must be passed a Device object "
294- "(e.g. `jax.devices('cpu')[0]`), got: 'cpu'" ):
295291 with jax .default_device ("cpu" ):
296- jax .jit (lambda x : x + 1 )(1 )
292+ result = jax .jit (lambda x : x + 1 )(1 )
293+ self .assertEqual (result .device .platform , "cpu" )
294+ self .assertEqual (result .device , jax .local_devices (backend = "cpu" )[0 ])
295+
296+ result = jax .jit (lambda x : x + 1 )(1 )
297+ self .assertEqual (result .device .platform , jax .default_backend ())
298+ self .assertEqual (result .device , jax .local_devices ()[0 ])
297299
298300 def test_complex_support (self ):
299301 self .assertEqual (jit (lambda x : x + 1 )(1 + 1j ), 2 + 1j )
You can’t perform that action at this time.
0 commit comments