Skip to content

Commit b6e4b93

Browse files
pschuhGoogle-ML-Automation
authored andcommitted
Add jaxlib_extension_version guard against explicit copying
in jax.device_put. PiperOrigin-RevId: 744838237
1 parent e1b0572 commit b6e4b93

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

jax/_src/dispatch.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from jax._src.interpreters import pxla
4545
from jax._src.interpreters import xla
4646
from jax._src.layout import DeviceLocalLayout, Layout
47+
from jax._src.lib import jaxlib_extension_version
4748
from jax._src.lib import xla_client as xc
4849
from jax._src.mesh import AbstractMesh, Mesh
4950
from jax._src.monitoring import record_event_duration_secs, record_event_time_span
@@ -495,7 +496,7 @@ def _device_put_sharding_impl(x, aval, device, copy):
495496
return _DeferredShardArg(x, x.sharding, aval, x.committed, copy)
496497
elif is_single_device_sharding(x.sharding):
497498
device = x.sharding._device_assignment[0] if device is None else device
498-
if copy == CopySemantics.COPY:
499+
if copy == CopySemantics.COPY and jaxlib_extension_version >= 327:
499500
return xc.batched_device_put(aval, SingleDeviceSharding(device), [x],
500501
[device], True, True)
501502
return pxla.batched_device_put(aval, SingleDeviceSharding(device), [x],

tests/pjit_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1401,7 +1401,7 @@ def test_zero_literal_equality(self):
14011401
self.assertIn("stablehlo.constant dense<-0.000000e+00>", ir)
14021402

14031403
def test_device_put_copy_donate(self):
1404-
if jaxlib_extension_version < 323:
1404+
if jaxlib_extension_version < 327:
14051405
raise unittest.SkipTest("Copy not supported in device put.")
14061406
x = np.arange(1000)
14071407
y = jax.device_put(x, device=jax.devices()[0], may_alias=False, donate=False)

0 commit comments

Comments
 (0)