Commit bc2b5d0

Merge pull request #214 from ROCm/ci-upstream-sync-98_1
CI: 01/28/25 upstream sync
2 parents: a366d41 + 47580ef

33 files changed, +1092 -418 lines

.bazelrc

Lines changed: 3 additions & 2 deletions
@@ -124,9 +124,10 @@ build:cuda --@local_config_cuda//:enable_cuda
 # Default hermetic CUDA and CUDNN versions.
 build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2"
 build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1"
+build:cuda --@local_config_cuda//cuda:include_cuda_libs=true
 
-# This flag is needed to include CUDA libraries for bazel tests.
-test:cuda --@local_config_cuda//cuda:include_cuda_libs=true
+# This config is used for building targets with CUDA libraries from stubs.
+build:cuda_libraries_from_stubs --@local_config_cuda//cuda:include_cuda_libs=false
 
 # Force the linker to set RPATH, not RUNPATH. When resolving dynamic libraries,
 # ld.so prefers in order: RPATH, LD_LIBRARY_PATH, RUNPATH. JAX sets RPATH to

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
@@ -16,6 +16,12 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
 
 ## Unreleased
 
+* New Features
+  * Added an experimental {func}`jax.experimental.custom_dce.custom_dce`
+    decorator to support customizing the behavior of opaque functions under
+    JAX-level dead code elimination (DCE). See {jax-issue}`#25956` for more
+    details.
+
 ## jax 0.5.0 (Jan 17, 2025)
 
 As of this release, JAX now uses

build/build.py

Lines changed: 1 addition & 0 deletions
@@ -532,6 +532,7 @@ async def main():
 
   if "cuda" in args.wheels:
     wheel_build_command_base.append("--config=cuda")
+    wheel_build_command_base.append("--config=cuda_libraries_from_stubs")
     if args.use_clang:
       wheel_build_command_base.append(
           f"--action_env=CLANG_CUDA_COMPILER_PATH=\"{clang_path}\""

ci/utilities/setup_build_environment.sh

Lines changed: 23 additions & 1 deletion
@@ -75,4 +75,26 @@ if [[ $(uname -s) =~ "MSYS_NT" ]]; then
   echo 'Converting MSYS Linux-like paths to Windows paths (for Bazel, Python, etc.)'
   # Convert all "JAXCI.*DIR" variables
   source <(python3 ./ci/utilities/convert_msys_paths_to_win_paths.py --convert $(env | grep "JAXCI.*DIR" | awk -F= '{print $1}'))
-fi
+fi
+
+function retry {
+  local cmd="$1"
+  local max_attempts=3
+  local attempt=1
+  local delay=10
+
+  while [[ $attempt -le $max_attempts ]] ; do
+    if eval "$cmd"; then
+      return 0
+    fi
+    echo "Attempt $attempt failed. Retrying in $delay seconds..."
+    sleep $delay # Prevent overloading
+
+    attempt=$((attempt + 1))
+  done
+  echo "$cmd failed after $max_attempts attempts."
+  exit 1
+}
+
+# Retry "bazel --version" 3 times to avoid flakiness when downloading bazel.
+retry "bazel --version"

docs/installation.md

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ The table below shows all supported platforms and installation options. Check if
 
 | | Linux, x86_64 | Linux, aarch64 | Mac, x86_64 | Mac, aarch64 | Windows, x86_64 | Windows WSL2, x86_64 |
 |------------------|---------------------------------------|---------------------------------|---------------------------------------|---------------------------------------|--------------------------|------------------------------------------|
-| CPU | {ref}`yes <install-cpu>` | {ref}`yes <install-cpu>` | {ref}`yes <install-cpu>` | {ref}`yes <install-cpu>` | {ref}`yes <install-cpu>` | {ref}`yes <install-cpu>` |
+| CPU | {ref}`yes <install-cpu>` | {ref}`yes <install-cpu>` | {ref}`jax≤0.4.38 only <install-cpu>` | {ref}`yes <install-cpu>` | {ref}`yes <install-cpu>` | {ref}`yes <install-cpu>` |
 | NVIDIA GPU | {ref}`yes <install-nvidia-gpu>` | {ref}`yes <install-nvidia-gpu>` | no | n/a | no | {ref}`experimental <install-nvidia-gpu>` |
 | Google Cloud TPU | {ref}`yes <install-google-tpu>` | n/a | n/a | n/a | n/a | n/a |
 | AMD GPU | {ref}`experimental <install-amd-gpu>` | no | {ref}`experimental <install-mac-gpu>` | n/a | no | no |

docs/jax.experimental.custom_dce.rst

Lines changed: 13 additions & 0 deletions

@@ -0,0 +1,13 @@
+``jax.experimental.custom_dce`` module
+======================================
+
+.. automodule:: jax.experimental.custom_dce
+
+API
+---
+
+.. autosummary::
+   :toctree: _autosummary
+
+   custom_dce
+   custom_dce.def_dce

docs/jax.experimental.rst

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ Experimental Modules
 
     jax.experimental.checkify
     jax.experimental.compilation_cache
+    jax.experimental.custom_dce
    jax.experimental.custom_partitioning
    jax.experimental.jet
    jax.experimental.key_reuse

jax/_src/core.py

Lines changed: 23 additions & 10 deletions
@@ -1731,14 +1731,20 @@ def _invalid_shape_error(shape: Shape, context: str=""):
 
   return TypeError(msg)
 
+def _make_lengths_same(sharding, ndim):
+  if ndim > len(sharding.spec):
+    return sharding.with_spec(sharding.spec._normalized_spec(ndim))
+  if ndim < len(sharding.spec):
+    return sharding.with_spec(sharding.spec[:ndim])
+  assert False, "unreachable"
+
+
 # TODO(yashkatariya): Only works with User/Auto. Generalize it to work with
 # Collective too.
 def modify_spec_for_auto_manual(spec, mesh) -> P:
-  if all(s is None for s in spec):
-    return spec
   new_spec = [] # type: ignore
   for s in spec:
-    if s is None:
+    if not s:
       new_spec.append(s)
     else:
       temp_s = s[0] if isinstance(s, tuple) else s
@@ -1748,22 +1754,29 @@ def modify_spec_for_auto_manual(spec, mesh) -> P:
                       else s)
   return P(*new_spec)
 
-def _maybe_modify_sharding(sharding):
+def _maybe_modify_sharding(sharding, ndim):
   if sharding.mesh._are_all_axes_explicit:
-    return sharding
-  new_spec = modify_spec_for_auto_manual(sharding.spec, sharding.mesh)
-  return sharding.with_spec(new_spec)
+    out = sharding
+  elif all(s is None for s in sharding.spec):
+    out = sharding
+  else:
+    out = sharding.with_spec(modify_spec_for_auto_manual(
+        sharding.spec, sharding.mesh))
+  if (len(out.spec) != ndim and
+      (out.mesh._are_all_axes_auto or out.mesh._are_all_axes_manual)):
+    out = _make_lengths_same(out, ndim)
+  return out
 
 
 def get_sharding(sharding, ndim):
   from jax._src.sharding_impls import NamedSharding # type: ignore
 
   if sharding is not None:
-    if len(sharding.spec) != ndim:
+    out_s = _maybe_modify_sharding(sharding, ndim)
+    if len(out_s.spec) != ndim:
       raise ValueError(
           "Length of sharding.spec must be equal to aval's ndim. Got"
-          f" sharding.spec {sharding.spec} and aval.ndim {ndim}")
-    out_s = _maybe_modify_sharding(sharding)
+          f" sharding.spec {out_s.spec} and aval.ndim {ndim}")
   else:
     context_mesh = mesh_lib.get_abstract_mesh()
     if not context_mesh:
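
The helper added above, _make_lengths_same, pads or truncates a sharding's PartitionSpec so its length matches the aval's rank before the length check in get_sharding. Below is a minimal standalone sketch of the same idea using only the public jax.sharding.PartitionSpec; the function name and the pad-with-None behavior are illustrative assumptions, not the internal implementation:

from jax.sharding import PartitionSpec as P

def make_lengths_same(spec: P, ndim: int) -> P:
  # Pad trailing dimensions as unsharded (None) when the spec is too short,
  # or drop trailing entries when it is too long.
  entries = tuple(spec)
  if ndim > len(entries):
    return P(*entries, *([None] * (ndim - len(entries))))
  return P(*entries[:ndim])

assert make_lengths_same(P("x"), 3) == P("x", None, None)
assert make_lengths_same(P("x", "y", None), 2) == P("x", "y")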

jax/_src/custom_dce.py

Lines changed: 9 additions & 9 deletions
@@ -75,9 +75,9 @@ class custom_dce:
   ...       x * jnp.sin(y) if used_outs[1] else None,
   ...   )
 
-  In this example, ``used_outs`` is a ``tuple`` with two ``bool``s indicating
-  which outputs are required. The DCE rule only computes the required outputs,
-  replacing the unused outputs with ``None``.
+  In this example, ``used_outs`` is a ``tuple`` with two ``bool`` values,
+  indicating which outputs are required. The DCE rule only computes the
+  required outputs, replacing the unused outputs with ``None``.
 
   If the ``static_argnums`` argument is provided to ``custom_dce``, the
   indicated arguments are treated as static when the function is traced, and
@@ -108,12 +108,12 @@ def def_dce(
 
     Args:
      dce_rule: A function that takes (a) any arguments indicated as static
-        using ``static_argnums``, (b) a Pytree of ``bool``s (``used_outs``)
-        indicating which outputs should be computed, and (c) the rest of the
-        (non-static) arguments to the original function. The rule should return
-        a Pytree with with the same structure as the output of the original
-        function, but any unused outputs (as indicated by ``used_outs``) can be
-        replaced with ``None``.
+        using ``static_argnums``, (b) a Pytree of ``bool`` values
+        (``used_outs``) indicating which outputs should be computed, and (c)
+        the rest of the (non-static) arguments to the original function. The
+        rule should return a Pytree with the same structure as the output
+        of the original function, but any unused outputs (as indicated by
+        ``used_outs``) can be replaced with ``None``.
     """
    self.dce_rule = dce_rule
    return dce_rule
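
Putting the docstring fragments above together, here is a short usage sketch of the new decorator, extrapolated from the example lines visible in this diff (the function bodies and the pairing of outputs are illustrative):

import jax.numpy as jnp
from jax.experimental.custom_dce import custom_dce

@custom_dce
def f(x, y):
  # Two outputs; each can be dropped independently by DCE.
  return jnp.sin(x) * y, x * jnp.sin(y)

@f.def_dce
def f_dce_rule(used_outs, x, y):
  # used_outs is a tuple of bools, one per output of f. Outputs that are not
  # needed may be returned as None so their computation can be skipped.
  return (
      jnp.sin(x) * y if used_outs[0] else None,
      x * jnp.sin(y) if used_outs[1] else None,
  )

When JAX-level DCE decides that only some outputs of f are needed (for example, when a jitted caller ignores the first output), the rule is invoked with the corresponding used_outs, and the unused entries can simply be None.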

jax/_src/lax/lax.py

Lines changed: 22 additions & 5 deletions
@@ -2149,7 +2149,7 @@ def full_like(x: ArrayLike | DuckTypedArray,
   if dtypes.issubdtype(dtype, dtypes.extended):
     return dtype._rules.full(fill_shape, fill_value, dtype) # type: ignore[union-attr]
 
-  if (config.sharding_in_types.value and sharding is None and
+  if (config.sharding_in_types.value and sharding is None and shape is None and
       isinstance(x, Array)):
     sharding = x.aval.sharding
   else:
@@ -4577,6 +4577,9 @@ def _clamp_shape_rule(min, operand, max):
                     f"(), got max.shape={max.shape}, {operand.shape=}.")
   return operand.shape
 
+def _clamp_sharding_rule(min, operand, max):
+  return operand.sharding
+
 _clamp_dtype_rule = partial(naryop_dtype_rule, _input_dtype, [_any, _any, _any],
                             'clamp')
 
@@ -4617,7 +4620,8 @@ def _clamp_batch_rule(batched_args, batch_dims, **params):
     x = broadcast(x, min.shape)
   return clamp_p.bind(min, x, max), 0
 
-clamp_p = standard_primitive(_clamp_shape_rule, _clamp_dtype_rule, 'clamp')
+clamp_p = standard_primitive(_clamp_shape_rule, _clamp_dtype_rule, 'clamp',
+                             sharding_rule=_clamp_sharding_rule)
 ad.defjvp(clamp_p,
           lambda g, min, operand, max:
           select(bitwise_and(gt(min, operand), lt(min, max)),
@@ -5165,18 +5169,28 @@ def _rev_shape_rule(operand, *, dimensions):
     raise TypeError(msg.format(dimensions, operand.ndim))
   return operand.shape
 
+def _rev_sharding_rule(operand, *, dimensions):
+  # TODO(yashkatariya): Will lead to data movement. Maybe just error out and
+  # require the operand to be unsharded?
+  return operand.sharding
+
 def _rev_batch_rule(batched_args, batch_dims, *, dimensions):
   operand, = batched_args
   bdim, = batch_dims
   new_dimensions = [i + 1 if i >= bdim else i for i in dimensions]
   return rev(operand, new_dimensions), bdim
 
-rev_p = standard_primitive(_rev_shape_rule, _input_dtype, 'rev')
+rev_p = standard_primitive(_rev_shape_rule, _input_dtype, 'rev',
+                           sharding_rule=_rev_sharding_rule)
 ad.deflinear2(rev_p, lambda t, _, dimensions: [rev(t, dimensions)])
 batching.primitive_batchers[rev_p] = _rev_batch_rule
 
 def _rev_lower(ctx, x, *, dimensions):
-  return [hlo.reverse(x, mlir.dense_int_array(dimensions))]
+  aval_out, = ctx.avals_out
+  out = hlo.reverse(x, mlir.dense_int_array(dimensions))
+  if config.sharding_in_types.value:
+    return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
+  return [out]
 mlir.register_lowering(rev_p, _rev_lower)
 
 
@@ -5932,7 +5946,10 @@ def _sort_lower(ctx, *operands, dimension, is_stable, num_keys):
                     mlir.flatten_ir_values(operands),
                     dimension=mlir.i64_attr(dimension),
                     is_stable=ir.BoolAttr.get(is_stable))
-  scalar_avals = [aval.update(shape=()) for aval in ctx.avals_in]
+  scalar_s = (lambda a: a.sharding.with_spec(P())
+              if config.sharding_in_types.value else lambda _: None)
+  scalar_avals = [aval.update(shape=(), sharding=scalar_s(aval))
+                  for aval in ctx.avals_in]
   scalar_types = safe_map(mlir.aval_to_ir_type, scalar_avals)
   comparator = sort.comparator.blocks.append(
       *util.flatten(zip(scalar_types, scalar_types)))
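
For orientation, clamp_p and rev_p above back the public lax.clamp and lax.rev operations; a minimal usage sketch of their plain (unsharded) semantics, independent of the sharding-in-types changes in this diff:

import jax.numpy as jnp
from jax import lax

x = jnp.array([-2.0, 0.5, 3.0])

# lax.clamp(min, operand, max): elementwise clamp; the output keeps the
# operand's shape (and, with sharding-in-types enabled, its sharding).
clamped = lax.clamp(jnp.float32(0.0), x, jnp.float32(1.0))  # [0.0, 0.5, 1.0]

# lax.rev(operand, dimensions): reverse the listed dimensions.
reversed_x = lax.rev(x, dimensions=[0])                     # [3.0, 0.5, -2.0]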
