Skip to content

Commit 549973d

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Allow pspec to be passed to device_put if there is a mesh in the surrounding context
PiperOrigin-RevId: 737812111
1 parent f174b00 commit 549973d

File tree

2 files changed

+30
-3
lines changed

2 files changed

+30
-3
lines changed

jax/_src/api.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,9 @@
6767
from jax._src.lib import xla_client as xc
6868
from jax._src.lib import pmap_lib
6969
from jax._src.sharding import Sharding
70-
from jax._src.sharding_impls import PmapSharding, TransferToMemoryKind
70+
from jax._src.mesh import get_concrete_mesh
71+
from jax._src.sharding_impls import (
72+
PmapSharding, TransferToMemoryKind, PartitionSpec as P, NamedSharding)
7173
from jax._src.layout import Layout, AutoLayout
7274
from jax._src.traceback_util import api_boundary
7375
from jax._src import tree_util
@@ -2280,11 +2282,20 @@ def _check_sharding(aval, s):
22802282
(s,), (aval,), ("",), "device_put args", allow_uneven_sharding=False)
22812283
s.shard_shape(aval.shape) # should raise an Error if incompatible
22822284

2285+
def pspec_to_sharding(val):
2286+
if isinstance(val, P):
2287+
mesh = get_concrete_mesh()
2288+
if mesh is None:
2289+
raise ValueError(
2290+
"Please set a mesh via `jax.sharding.use_mesh` if a PartitionSpec is"
2291+
" passed to device_put")
2292+
return NamedSharding(mesh, val)
2293+
return val
22832294

22842295
def device_put(
22852296
x,
2286-
device: None | xc.Device | Sharding | Layout | Any | TransferToMemoryKind = None,
2287-
*, src: None | xc.Device | Sharding | Layout | Any | TransferToMemoryKind = None,
2297+
device: None | xc.Device | Sharding | P | Layout | Any | TransferToMemoryKind = None,
2298+
*, src: None | xc.Device | Sharding | P | Layout | Any | TransferToMemoryKind = None,
22882299
donate: bool | Any = False, may_alias: bool | None | Any = None):
22892300
"""Transfers ``x`` to ``device``.
22902301
@@ -2333,6 +2344,9 @@ def device_put(
23332344
src_flat = flatten_axes("device_put source", treedef, src)
23342345
src_flat = list(map(_infer_src_sharding, src_flat, x_flat))
23352346

2347+
device_flat = map(pspec_to_sharding, device_flat)
2348+
src_flat = map(pspec_to_sharding, src_flat)
2349+
23362350
if isinstance(donate, bool):
23372351
donate_flat = [donate] * len(x_flat)
23382352
else:

tests/pjit_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6138,6 +6138,19 @@ def f(x):
61386138
self.assertDictEqual(out.sharding.mesh._axis_types_dict,
61396139
{AxisType.Auto: ('x',)})
61406140

6141+
@jtu.with_user_mesh((2,), 'x')
6142+
def test_device_put_use_mesh(self, mesh):
6143+
out = jax.device_put(np.arange(8), P('x'))
6144+
self.assertArraysEqual(out, np.arange(8))
6145+
self.assertEqual(out.sharding, NamedSharding(mesh, P('x')))
6146+
6147+
def test_device_put_no_use_mesh_error(self):
6148+
with self.assertRaisesRegex(
6149+
ValueError,
6150+
'Please set a mesh via `jax.sharding.use_mesh` if a PartitionSpec is'
6151+
' passed to device_put'):
6152+
jax.device_put(np.arange(8), P('x'))
6153+
61416154
@jtu.with_user_mesh((2,), 'x')
61426155
def test_inputs_different_context(self, mesh):
61436156
np_inp = np.arange(16).reshape(8, 2)

0 commit comments

Comments
 (0)