Skip to content

Commit fd4b160

Browse files
Use JAX's default device instead of jax.devices()[0], if set.
PiperOrigin-RevId: 702515221
1 parent fcf0b6d commit fd4b160

File tree

5 files changed

+9
-4
lines changed

5 files changed

+9
-4
lines changed

jax/_src/dispatch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ def _device_put_sharding_impl(x, aval, device, copy):
440440
return pxla.batched_device_put(aval, SingleDeviceSharding(device), [x],
441441
[device])
442442

443-
sh = SingleDeviceSharding(pxla._get_default_device()
443+
sh = SingleDeviceSharding(pxla.get_default_device()
444444
if device is None else device)
445445
return _DeferredShardArg(x, sh, aval, device is not None, copy)
446446

jax/_src/interpreters/pxla.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1710,7 +1710,7 @@ class DeviceAssignmentMismatchError(Exception):
17101710
]
17111711

17121712

1713-
def _get_default_device() -> xc.Device:
1713+
def get_default_device() -> xc.Device:
17141714
if isinstance(config.default_device.value, str):
17151715
return xb.get_backend(config.default_device.value).local_devices()[0]
17161716
else:
@@ -1749,7 +1749,7 @@ def _get_and_check_device_assignment(
17491749
if first_sharding_info is None and devices:
17501750
final_device_assignment = devices
17511751
elif first_sharding_info is None:
1752-
final_device_assignment = (_get_default_device(),)
1752+
final_device_assignment = (get_default_device(),)
17531753
else:
17541754
final_device_assignment = first_sharding_info[0] # type: ignore
17551755
return xb.get_device_backend(final_device_assignment[0]), final_device_assignment

jax/_src/pallas/mosaic/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ py_library(
124124
"//jax:pallas",
125125
"//jax:util",
126126
"//jax/_src/pallas",
127+
"//jax/extend:backend",
127128
] + py_deps("numpy"),
128129
)
129130

jax/_src/pallas/mosaic/pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from jax._src.pallas.mosaic import core as tpu_core
3434
from jax._src.pallas.mosaic import primitives as tpu_primitives
3535
from jax.experimental import pallas as pl
36+
from jax.extend.backend import get_default_device
3637
import jax.numpy as jnp
3738
import numpy as np
3839

@@ -75,7 +76,7 @@ def add_leaves(i, x):
7576

7677
@jax_util.cache(trace_context_in_key=False)
7778
def _get_tpu_generation() -> int:
78-
kind = jax.devices()[0].device_kind
79+
kind = get_default_device().device_kind
7980
if kind.endswith(' lite'):
8081
kind = kind[:-len(' lite')]
8182
assert kind[:5] == "TPU v", kind

jax/extend/backend.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,6 @@
2424
get_backend as get_backend,
2525
register_backend_factory as register_backend_factory,
2626
)
27+
from jax._src.interpreters.pxla import (
28+
get_default_device as get_default_device
29+
)

0 commit comments

Comments
 (0)