
Commit 0b970b8

Merge pull request #135 from ROCm/ci-upstream-sync-13_1
CI: 11/08/24 upstream sync
2 parents: 9afbd23 + ced1e2b · commit 0b970b8

33 files changed: +978 / -174 lines

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -43,6 +43,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
 * {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for
   passing compilation options to XLA. For the moment it's undocumented and
   may be in flux.
+* {func}`jax.tree_util.register_dataclass` now allows metadata fields to be
+  declared inline via {func}`dataclasses.field`. See the function documentation
+  for examples.
 
 ## jax 0.4.35 (Oct 22, 2024)
 
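A minimal sketch of the inline declaration this changelog entry describes, assuming metadata fields are marked with `dataclasses.field(metadata=dict(static=True))` and that `data_fields`/`meta_fields` can then be omitted from the `register_dataclass` call:

```python
import dataclasses

import jax
import jax.numpy as jnp


@jax.tree_util.register_dataclass
@dataclasses.dataclass
class Params:
  weights: jax.Array  # data field: flattened as a pytree leaf and traced
  # Assumed convention: static=True in the field metadata marks a metadata field.
  label: str = dataclasses.field(metadata=dict(static=True))


@jax.jit
def scale(p: Params) -> jax.Array:
  # `weights` is traced; `label` travels as static (hashable) aux data.
  return 2.0 * p.weights


print(scale(Params(weights=jnp.ones(3), label="layer0")))
```
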
docs/persistent_compilation_cache.md

Lines changed: 8 additions & 1 deletion
@@ -1,11 +1,18 @@
 # Persistent compilation cache
 
-<!--* freshness: { reviewed: '2024-04-09' } *-->
+<!--* freshness: { reviewed: '2024-11-07' } *-->
 
 JAX has an optional disk cache for compiled programs. If enabled, JAX will
 store copies of compiled programs on disk, which can save recompilation time
 when running the same or similar tasks repeatedly.
 
+Note: if the compilation cache is not on a local filesystem,
+[etils](https://pypi.org/project/etils/) needs to be installed.
+
+```python
+pip install etils
+```
+
 ## Usage
 
 ### Quick start

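For context, a hedged sketch of the case the new note covers: pointing the cache at a non-local filesystem path (the bucket below is a placeholder), which is when etils is needed. The cache directory is set via the `jax_compilation_cache_dir` option:

```python
import jax
import jax.numpy as jnp

# Placeholder path; any non-local filesystem (e.g. gs:// or s3://) requires etils.
jax.config.update("jax_compilation_cache_dir", "gs://example-bucket/jax-cache")


@jax.jit
def f(x):
  return x * 2


# Compiled programs that meet the cache criteria are written to the cache
# directory and reused on later runs instead of being recompiled.
f(jnp.arange(4))
```
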
jax/_src/api.py

Lines changed: 1 addition & 1 deletion
@@ -1599,7 +1599,7 @@ def cache_miss(*args, **kwargs):
 
   cpp_mapped_f = pmap_lib.pmap(
       fun, cache_miss, static_broadcasted_tuple,
-      lambda x, s: pxla.shard_args([s], [None], [x])[0],
+      lambda x, s: pxla.shard_args([s], [None], [None], [x])[0],
       pytree_registry=tree_util.default_registry)
   _pmap_cache_clears.add(cpp_mapped_f)
 
jax/_src/array.py

Lines changed: 16 additions & 7 deletions
@@ -40,6 +40,7 @@
 from jax._src.layout import AutoLayout, DeviceLocalLayout, Layout
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension as xe
+from jax._src.lib import xla_extension_version
 from jax._src.sharding import Sharding
 from jax._src.sharding_impls import (
     PmapSharding, SingleDeviceSharding, NamedSharding,
@@ -1110,7 +1111,7 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
     # Look up all buffers that contain the correct slice of the logical array.
     candidates_list = candidates[hashed_index(idx)]
     if not candidates_list:
-      return pxla.shard_args([sharding], [None], [x._value],
+      return pxla.shard_args([sharding], [None], [None], [x._value],
                              canonicalize=False)[0]
     # Try to find a candidate buffer already on the correct device,
     # otherwise copy one of them.
@@ -1130,11 +1131,13 @@ def _sharding_indices_and_eq(src_sharding, shape, dst_sharding):
   return dst_indices, tuple(src_indices) == tuple(dst_indices)
 
 
-def _array_shard_arg(xs, shardings, layouts):
+def _array_shard_arg(xs, shardings, layouts, copy_semantics):
   results = []
   batch_xs, batch_devs, batch_shardings, batch_indices = [], [], [], []
+  batch_cs = []
 
-  for i, (x, sharding, layout) in enumerate(safe_zip(xs, shardings, layouts)):
+  for i, (x, sharding, layout, cs) in enumerate(
+      safe_zip(xs, shardings, layouts, copy_semantics)):
     x._check_if_deleted()
     indices, same_indices = _sharding_indices_and_eq(x.sharding, x.shape, sharding)
     same_layout = (True if layout is None else
@@ -1156,6 +1159,7 @@ def _array_shard_arg(xs, shardings, layouts):
       batch_devs.append(list(devices))
       batch_shardings.append(sharding)
       batch_indices.append(i)
+      batch_cs.append(cs)
     # Resharding starts here:
     elif not same_layout:
       results.append(api.device_put(x, Layout(layout, sharding)))
@@ -1165,8 +1169,12 @@ def _array_shard_arg(xs, shardings, layouts):
       results.append(
           shard_sharded_device_array_slow_path(x, devices, indices, sharding))
 
-  copy_outs = xc.batched_copy_array_to_devices_with_sharding(
-      batch_xs, batch_devs, batch_shardings)
+  if xla_extension_version >= 296:
+    copy_outs = xc.batched_copy_array_to_devices_with_sharding(
+        batch_xs, batch_devs, batch_shardings, batch_cs)
+  else:
+    copy_outs = xc.batched_copy_array_to_devices_with_sharding(  # type: ignore
+        batch_xs, batch_devs, batch_shardings)
   for i, copy_out in safe_zip(batch_indices, copy_outs):
     assert results[i] is None
     results[i] = copy_out
@@ -1200,8 +1208,9 @@ def _array_local_result_handler(aval, sharding, indices):
 
 # Token handlers
 
-def _token_shard_arg(xs, shardings, layouts):
-  return _array_shard_arg([x._buf for x in xs], shardings, layouts)
+def _token_shard_arg(xs, shardings, layouts, copy_semantics):
+  return _array_shard_arg([x._buf for x in xs], shardings, layouts,
+                          copy_semantics)
 pxla.shard_arg_handlers[core.Token] = _token_shard_arg
 
 
jax/_src/dispatch.py

Lines changed: 13 additions & 6 deletions
@@ -137,7 +137,7 @@ def get_token_input(
     # We only use replicated sharding for the first time when the token for the
     # order effect hasn't been created.
     s = jax.sharding.GSPMDSharding.get_replicated(devices)
-    sharded_tok = core.Token(pxla.shard_args([s], [None], [tok])[0])
+    sharded_tok = core.Token(pxla.shard_args([s], [None], [None], [tok])[0])
     self.current_tokens[eff] = sharded_tok
     return sharded_tok
 
@@ -391,6 +391,7 @@ class _DeferredShardArg:
   s: Sharding
   aval: core.AbstractValue
   committed: bool
+  copy_semantics: CopySemantics
 
   @property
   def result_handler(self):
@@ -435,24 +436,27 @@ def _device_put_sharding_impl(x, aval, device, copy):
          "device_put's second argument must be a Device or a Sharding which"
          f" represents addressable devices, but got {s}. Please pass device or"
          " Sharding which represents addressable devices.")
-    return _DeferredShardArg(x, s, aval, True)
+    return _DeferredShardArg(x, s, aval, True, copy)
 
   # Only `Device` exists below. `Sharding` instance is handled above.
   if isinstance(x, array.ArrayImpl):
     if not x.is_fully_addressable:
       raise ValueError(
          "device_put's first argument must be a fully addressable array, but "
          f"got value with devices {x.devices()}")
-    if device is None and copy == CopySemantics.ALIAS:
-      return x
+    if device is None:
+      if copy == CopySemantics.ALIAS:
+        return x
+      else:
+        return _DeferredShardArg(x, x.sharding, aval, x.committed, copy)
    elif is_single_device_sharding(x.sharding):
      device = x.sharding._device_assignment[0] if device is None else device
      return pxla.batched_device_put(aval, SingleDeviceSharding(device), [x],
                                     [device])
 
   sh = SingleDeviceSharding(pxla._get_default_device()
                             if device is None else device)
-  return _DeferredShardArg(x, sh, aval, device is not None)
+  return _DeferredShardArg(x, sh, aval, device is not None, copy)
 
 
 def _device_put_impl(
@@ -501,12 +505,14 @@ def _batched_device_put_impl(
     copy_semantics: Sequence[CopySemantics]):
   ys = []
   shard_arg_indices, shard_arg_xs, shard_arg_shardings = [], [], []
+  shard_arg_copy_semantics = []
   for i, (x, device, src, cp) in enumerate(zip(xs, devices, srcs, copy_semantics)):
     y = _device_put_impl(x, device=device, src=src, copy=cp)
     if isinstance(y, _DeferredShardArg):
       shard_arg_indices.append(i)
       shard_arg_xs.append(y.x)
       shard_arg_shardings.append(y.s)
+      shard_arg_copy_semantics.append(y.copy_semantics)
     ys.append(y)
 
   if shard_arg_xs:
@@ -515,7 +521,8 @@ def _batched_device_put_impl(
     # device_put handles `Layout` via a different path, so just pass `None` as
     # the layout here.
     shard_arg_results = pxla.shard_args(
-        shard_arg_shardings, [None] * len(shard_arg_xs), shard_arg_xs)
+        shard_arg_shardings, [None] * len(shard_arg_xs),
+        shard_arg_copy_semantics, shard_arg_xs)
     for i, shard_arg_result in zip(shard_arg_indices, shard_arg_results):
       assert isinstance(ys[i], _DeferredShardArg)
       ys[i] = ys[i].result_handler(shard_arg_result)

jax/_src/earray.py

Lines changed: 2 additions & 2 deletions
@@ -108,12 +108,12 @@ def global_shards(self):
 
 # TODO(mattjj): _set_array_base_attributes
 
-def _earray_shard_arg_handler(xs, shardings, layouts):
+def _earray_shard_arg_handler(xs, shardings, layouts, copy_semantics):
   arrs = [x._data for x in xs]
   phys_shardings = [sharding_impls.physical_sharding(x.aval, sharding)
                     for x, sharding in zip(xs, shardings)]
   # TODO(yashkatariya): `layouts` should be converted to physical layouts.
-  return pxla.shard_args(phys_shardings, layouts, arrs)
+  return pxla.shard_args(phys_shardings, layouts, copy_semantics, arrs)
 pxla.shard_arg_handlers[EArray] = _earray_shard_arg_handler
 
 api_util._shaped_abstractify_handlers[EArray] = lambda self: self.aval

jax/_src/interpreters/pxla.py

Lines changed: 45 additions & 18 deletions
@@ -61,6 +61,7 @@
 from jax._src.interpreters import xla
 from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout
 from jax._src.lib import xla_client as xc
+from jax._src.lib import xla_extension_version
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.partition_spec import PartitionSpec
@@ -105,44 +106,69 @@ class WeakRefList(list):
 
 ### util
 
+
+def to_xc_copy_semantics(copy_semantics):
+  if xla_extension_version < 296:
+    return [None] * len(copy_semantics)
+  out = []
+  for cs in copy_semantics:
+    if cs is None or cs == dispatch.CopySemantics.ALIAS:
+      out.append(xc.ArrayCopySemantics.REUSE_INPUT)
+    elif cs == dispatch.CopySemantics.COPY:
+      out.append(xc.ArrayCopySemantics.ALWAYS_COPY)
+    elif cs == dispatch.CopySemantics.DONATE:
+      out.append(xc.ArrayCopySemantics.DONATE_INPUT)
+    else:
+      assert isinstance(cs, xc.ArrayCopySemantics)
+      out.append(cs)
+  return out
+
+
 def identity(x): return x
 
 @profiler.annotate_function
-def shard_args(shardings: Sequence[JSharding], layouts, args,
-               canonicalize=True) -> Sequence[xc.ArrayImpl]:
+def shard_args(shardings: Sequence[JSharding], layouts, copy_semantics,
+               args, canonicalize=True) -> Sequence[xc.ArrayImpl]:
+  xc_copy_semantics = to_xc_copy_semantics(copy_semantics)
+  del copy_semantics
   # Fast path for one argument.
   if len(args) == 1:
     arg = args[0]
     if canonicalize:
       arg = xla.canonicalize_dtype(arg)
-    return shard_arg_handlers[type(arg)]([arg], shardings, layouts)
-
-  # type(arg) -> (list[indices], list[args], list[shardings])
-  batches = collections.defaultdict(lambda: ([], [], [], []))  # type: ignore
-  for i, (arg, sharding, layout) in enumerate(safe_zip(args, shardings, layouts)):
+    return shard_arg_handlers[type(arg)]([arg], shardings, layouts,
+                                         xc_copy_semantics)
+
+  # type(arg) -> (list[indices], list[args], list[shardings], list[layouts],
+  #               list[copy_semantics])
+  batches = collections.defaultdict(lambda: ([], [], [], [], []))  # type: ignore
+  for i, (arg, sharding, layout, cs) in enumerate(
+      safe_zip(args, shardings, layouts, xc_copy_semantics)):
     if canonicalize:
       arg = xla.canonicalize_dtype(arg)
     batch = batches[type(arg)]
     batch[0].append(i)
     batch[1].append(arg)
     batch[2].append(sharding)
     batch[3].append(layout)
+    batch[4].append(cs)
 
   # Call `shard_arg_handlers` per batch and build a flat list of arrays returned
   # from each call in the same order as `args`. Since `batches` is grouped by
   # types, we cannot simply flatten the results and we have to use the original
   # indices to put each array back to its original position.
   results: list[jax.Array | None] = [None] * len(args)
-  for t, (indices, a, s, l) in batches.items():
-    outs = shard_arg_handlers[t](a, s, l)
+  for t, (indices, a, s, l, cs) in batches.items():
+    outs = shard_arg_handlers[t](a, s, l, cs)
     for i, out in safe_zip(indices, outs):
       results[i] = out
   assert all(result is not None for result in results)
   return results
 
 
 shard_arg_handlers: dict[
-    Any, Callable[[Sequence[Any], Sequence[Any], Sequence[Any]], Sequence[Any]]
+    Any, Callable[[Sequence[Any], Sequence[Any], Sequence[Any], Sequence[Any]],
+                  Sequence[Any]]
 ] = {}
 
 
@@ -172,12 +198,12 @@ def is_default_layout(curr_layout, sharding, aval):
     raise
 
 
-def _masked_array_error(xs, shardings, layouts):
+def _masked_array_error(xs, shardings, layouts, copy_semantics):
   raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
                    "Use arr.filled() to convert the value to a standard numpy array.")
 shard_arg_handlers[np.ma.MaskedArray] = _masked_array_error
 
-def _shard_np_array(xs, shardings, layouts):
+def _shard_np_array(xs, shardings, layouts, copy_semantics):
   results = []
   for x, sharding, layout in safe_zip(xs, shardings, layouts):
     devices = sharding._addressable_device_assignment
@@ -197,12 +223,12 @@ def _shard_np_array(xs, shardings, layouts):
 for _t in array_types:
   shard_arg_handlers[_t] = _shard_np_array
 
-def _shard_darray(xs, shardings, layouts):
-  return shard_args(shardings, layouts, [x._data for x in xs])
+def _shard_darray(xs, shardings, layouts, copy_semantics):
+  return shard_args(shardings, layouts, copy_semantics, [x._data for x in xs])
 shard_arg_handlers[core.DArray] = _shard_darray
 
-def _shard_mutable_array(xs, shardings, layouts):
-  return shard_args(shardings, layouts, [x._buf for x in xs])
+def _shard_mutable_array(xs, shardings, layouts, copy_semantics):
+  return shard_args(shardings, layouts, copy_semantics, [x._buf for x in xs])
 shard_arg_handlers[core.MutableArray] = _shard_mutable_array
 
 def batched_device_put(aval: core.ShapedArray,
@@ -1135,7 +1161,8 @@ class InputsHandler:
 
   def __init__(self, in_shardings, in_layouts, local_devices=None,
               input_indices=None):
-    self.handler = partial(shard_args, in_shardings, in_layouts)
+    self.handler = partial(shard_args, in_shardings, in_layouts,
+                           [None] * len(in_shardings))
     self.in_shardings = in_shardings
     self.in_layouts = in_layouts
     self.local_devices = local_devices
@@ -3047,7 +3074,7 @@ def aot_cache_miss(*args, **kwargs):
       JitGlobalCppCacheKeys(), tree_util.dispatch_registry, cc_shard_arg)
 
 def cc_shard_arg(x, sharding, layout):
-  return shard_args([sharding], [layout], [x])[0]
+  return shard_args([sharding], [layout], [None], [x])[0]
 
 
 def check_arg_avals_for_call(ref_avals, arg_avals,
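
Taken together, these pxla.py changes thread a per-array `copy_semantics` sequence through `shard_args` and into every `shard_arg_handlers` entry, which now receives four sequences instead of three. A hedged sketch of how an out-of-tree handler would look after this change (`MyWrapper` and its `.inner` attribute are hypothetical), mirroring the `_shard_darray`/`_shard_mutable_array` handlers in the diff:

```python
from jax._src.interpreters import pxla


def _my_wrapper_shard_arg(xs, shardings, layouts, copy_semantics):
  # Handlers now take a fourth sequence, aligned with xs/shardings/layouts.
  # A handler that just unwraps a container can forward it to shard_args unchanged.
  return pxla.shard_args(shardings, layouts, copy_semantics,
                         [x.inner for x in xs])


# Registration is unchanged; only the handler arity differs:
# pxla.shard_arg_handlers[MyWrapper] = _my_wrapper_shard_arg
```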

jax/_src/lax/control_flow/loops.py

Lines changed: 1 addition & 1 deletion
@@ -704,7 +704,7 @@ def _maybe_put(x):
     aval = shaped_abstractify(x)
     s = sharding.SingleDeviceSharding(xb.local_devices(backend='cpu')[0])
     result_handler = pxla.global_aval_to_result_handler(aval, s, False)
-    return result_handler(pxla.shard_args([s], [None], [x]))
+    return result_handler(pxla.shard_args([s], [None], [None], [x]))
   else:
     return x
 