|
67 | 67 | from jax._src.lib import xla_client as xc |
68 | 68 | from jax._src.lib import pmap_lib |
69 | 69 | from jax._src.sharding import Sharding |
70 | | -from jax._src.sharding_impls import PmapSharding, TransferToMemoryKind |
| 70 | +from jax._src.mesh import get_concrete_mesh |
| 71 | +from jax._src.sharding_impls import ( |
| 72 | + PmapSharding, TransferToMemoryKind, PartitionSpec as P, NamedSharding) |
71 | 73 | from jax._src.layout import Layout, AutoLayout |
72 | 74 | from jax._src.traceback_util import api_boundary |
73 | 75 | from jax._src import tree_util |
@@ -2280,11 +2282,20 @@ def _check_sharding(aval, s): |
2280 | 2282 | (s,), (aval,), ("",), "device_put args", allow_uneven_sharding=False) |
2281 | 2283 | s.shard_shape(aval.shape) # should raise an Error if incompatible |
2282 | 2284 |
|
| 2285 | +def pspec_to_sharding(val): |
| 2286 | + if isinstance(val, P): |
| 2287 | + mesh = get_concrete_mesh() |
| 2288 | + if mesh is None: |
| 2289 | + raise ValueError( |
| 2290 | + "Please set a mesh via `jax.sharding.use_mesh` if a PartitionSpec is" |
| 2291 | + " passed to device_put") |
| 2292 | + return NamedSharding(mesh, val) |
| 2293 | + return val |
2283 | 2294 |
|
2284 | 2295 | def device_put( |
2285 | 2296 | x, |
2286 | | - device: None | xc.Device | Sharding | Layout | Any | TransferToMemoryKind = None, |
2287 | | - *, src: None | xc.Device | Sharding | Layout | Any | TransferToMemoryKind = None, |
| 2297 | + device: None | xc.Device | Sharding | P | Layout | Any | TransferToMemoryKind = None, |
| 2298 | + *, src: None | xc.Device | Sharding | P | Layout | Any | TransferToMemoryKind = None, |
2288 | 2299 | donate: bool | Any = False, may_alias: bool | None | Any = None): |
2289 | 2300 | """Transfers ``x`` to ``device``. |
2290 | 2301 |
|
@@ -2333,6 +2344,9 @@ def device_put( |
2333 | 2344 | src_flat = flatten_axes("device_put source", treedef, src) |
2334 | 2345 | src_flat = list(map(_infer_src_sharding, src_flat, x_flat)) |
2335 | 2346 |
|
| 2347 | + device_flat = map(pspec_to_sharding, device_flat) |
| 2348 | + src_flat = map(pspec_to_sharding, src_flat) |
| 2349 | + |
2336 | 2350 | if isinstance(donate, bool): |
2337 | 2351 | donate_flat = [donate] * len(x_flat) |
2338 | 2352 | else: |
|
0 commit comments